from __future__ import absolute_import

import logging
import traceback
import binascii

from Crypto.Cipher import AES
from Crypto import Random


logger = logging.getLogger(__name__)


def encrypt_aes256gcm(plain_text, key, tag_len):
    """
    Encrypt ``plain_text`` with AES in GCM mode using a fresh random nonce.

    :param plain_text: str or bytes; a str is UTF-8 encoded before encryption
    :param key: base64-encoded AES key (str or bytes); its decoded length is
        not checked here and is passed straight to ``AES.new``
    :param tag_len: int, authentication tag length in bytes (4..16)
    :return: tuple ``(crypt_text, tag, iv)``, each produced by
        ``binascii.b2a_base64`` (base64 with a trailing newline; bytes on
        Python 3). On any failure the error is logged and ``('', '', '')``
        is returned instead.
    """
    try:
        if isinstance(plain_text, bytes):
            data = plain_text
        elif isinstance(plain_text, str):
            data = plain_text.encode("utf-8")
        else:
            raise Exception(
                "type of param plain_text error, must be str or bytes")

        # NOTE(review): a 32-byte random nonce is used; NIST SP 800-38D
        # recommends 96-bit (12-byte) IVs for GCM. Kept at 32 in case peers
        # depend on this length -- TODO confirm before changing.
        nonce = Random.get_random_bytes(32)
        raw_key = binascii.a2b_base64(key)
        cipher = AES.new(key=raw_key, mode=AES.MODE_GCM,
                         nonce=nonce, mac_len=tag_len)
        crypt_b, digest_b = cipher.encrypt_and_digest(data)

        return (binascii.b2a_base64(crypt_b),
                binascii.b2a_base64(digest_b),
                binascii.b2a_base64(nonce))

    except Exception:
        # format_exc() formats the exception currently being handled and takes
        # no exception argument; passing one is treated as ``limit`` and
        # raises TypeError on Python 3, defeating this best-effort fallback.
        logger.error("aes256gcm encrypt fail, "
                     "exception： %s", traceback.format_exc())
        return '', '', ''


def decrypt_aes256gcm(crypt_text, key, iv, tag):
    """
    Decrypt and authenticate an AES-GCM ciphertext.

    Every argument is base64-encoded (str or bytes), for example the values
    returned by ``encrypt_aes256gcm``.

    :param crypt_text: base64 ciphertext
    :param key: base64 AES key
    :param iv: base64 nonce that was used for encryption
    :param tag: base64 authentication tag; its decoded length is used as the
        expected MAC length
    :return: plaintext decoded as a UTF-8 str, or '' if base64 decoding,
        tag verification or UTF-8 decoding fails (the error is logged)
    """
    try:
        mac = binascii.a2b_base64(tag)
        nonce = binascii.a2b_base64(iv)
        raw_key = binascii.a2b_base64(key)
        crypt_b = binascii.a2b_base64(crypt_text)
        cipher = AES.new(key=raw_key, mode=AES.MODE_GCM,
                         nonce=nonce, mac_len=len(mac))
        # decrypt_and_verify raises if the tag does not match (wrong key,
        # wrong nonce, or tampered data), which lands in the handler below.
        return cipher.decrypt_and_verify(crypt_b, mac).decode("utf-8")
    except Exception:
        # format_exc() takes no exception argument; passing one is treated as
        # ``limit`` and raises TypeError on Python 3.
        logger.error("aes256gcm decrypt fail, "
                     "exception：%s", traceback.format_exc())
        return ''


if __name__ == '__main__':
    # Round-trip smoke test using the project's configured key.
    # A throwaway key can instead be generated with
    # binascii.b2a_base64(Random.get_random_bytes(32)).
    from conf.conf import aes_gcm_key as demo_key

    # Encrypt a sample string and show the three base64 outputs.
    ciphertext_b64, tag_b64, iv_b64 = encrypt_aes256gcm(
        "******", key=demo_key, tag_len=16)
    print("crypt_text：", ciphertext_b64)
    print("tag：", tag_b64)
    print("iv：", iv_b64)

    # Decrypt them again and show the recovered plaintext.
    recovered = decrypt_aes256gcm(ciphertext_b64, demo_key, iv_b64, tag_b64)
    print("plain_text:", recovered)
